from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from torch.utils.data import DataLoader
from torchdata.datapipes.map import MapDataPipe
from pytorch_lightning import LightningDataModule
import warnings
import sys
from Executor import Executor
from utils import *
from bw_utils import *
import yaml
import json
from tarski.io import PDDLReader

def get_problem(instance, domain):
    """Parse a PDDL domain and then an instance against it.

    Args:
        instance: Passed to ``PDDLReader.parse_instance`` (presumably a path to
            a PDDL problem file -- confirm against callers).
        domain: Passed to ``PDDLReader.parse_domain`` (presumably a path to the
            PDDL domain file).

    Returns:
        Whatever ``PDDLReader.parse_instance`` returns for ``instance``.
    """
    # raise_on_error=True makes malformed PDDL fail loudly instead of silently.
    pddl_reader = PDDLReader(raise_on_error=True)
    pddl_reader.parse_domain(domain)
    parsed_problem = pddl_reader.parse_instance(instance)
    return parsed_problem

class PromptDataModule(LightningDataModule):
    """Lightning data module serving Blocksworld planning problems as text.

    Each served record is an ``[INIT, GOAL, PLAN]`` list produced by
    ``instance_to_text_blocksworld`` from a raw ``[instance, plan_text]`` entry.

    Args:
        args: Run configuration; ``args.step`` selects the training JSON file.
        tokenizer: Stored on the module; not used by this class directly.
        train_size: Fraction of the (possibly limited) training records kept
            for training when ``test_file`` is None; the remainder is held out
            as the test split.
        limit_prompts: If not None, only the first ``limit_prompts`` training
            entries are used.
        test_file: Optional path to a JSON list of raw ``[instance, plan_text]``
            entries used as the test split. When given, the training records
            are not split by ``train_size``.
        max_train: Maximum number of records served by the train/val loaders.
        max_test: Maximum number of records served by the test loader.
    """

    def __init__(
        self,
        args,
        tokenizer,
        train_size=0.2,
        limit_prompts=None,
        test_file=None,
        max_train=15,
        max_test=50,
    ):
        super().__init__()
        self.save_hyperparameters(ignore="tokenizer")
        # The same config file used to be parsed twice (into `data` and
        # `config`); parse it once and expose it under both names.
        with open('data/blocksworld/bw_config.yaml', 'r') as file:
            self.config = yaml.safe_load(file)
        self.data = self.config
        with open("data/blocksworld/my_mcts_prompts_update.json", 'r') as file:
            self.prompts = json.load(file)
        self.domain_pddl = f'gpt-plan-benchmark/gpt_plan_test/instances/{self.config["domain_file"]}'
        self.base_prompt = "I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do\n\nPick up a block\nUnstack a block from on top of another block\nPut down a block\nStack a block on top of another block\n\nI have the following restrictions on my actions:\nI can only pick up or unstack one block at a time.\nI can only pick up or unstack a block if my hand is empty.\nI can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up.\nI can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block.\nI can only unstack a block from on top of another block if the block I am unstacking is clear.\nOnce I pick up or unstack a block, I am holding the block.\nI can only put down a block that I am holding.\nI can only stack a block on top of another block if I am holding the block being stacked.\nI can only stack a block on top of another block if the block onto which I am stacking the block is clear.\nOnce I put down or stack a block, my hand becomes empty.\nAfter being given an initial state and an action, give the new state after performing the action.\n"
        self.tokenizer = tokenizer
        self.args = args
        self.train_data = None
        self.val_data = None
        self.test_data = None

    @staticmethod
    def _load_json(path):
        """Load and return the JSON document at ``path``, closing the file."""
        with open(path, 'r') as file:
            return json.load(file)

    def _to_text_records(self, raw_entries):
        """Convert raw ``[instance, plan_text]`` entries to ``[INIT, GOAL, PLAN]`` lists."""
        records = []
        for entry in raw_entries:
            # entry[0] goes to the PDDL instance parser and entry[1] is the
            # ground-truth plan text (presumably; confirm against the data files).
            problem = get_problem(entry[0], self.domain_pddl)
            init, goal, plan = instance_to_text_blocksworld(problem, True, entry[1], self.data)
            records.append([init, goal, plan])
        return records

    def setup(self, stage):
        """Build the train/val/test datapipes.

        Args:
            stage: Lightning stage name; ignored -- all splits are built.
        """
        # NOTE(review): absolute path, unlike the relative `data/blocksworld/...`
        # paths used in __init__ -- confirm this is intentional.
        raw_train = self._load_json(f"/data/blocksworld/step_{self.args.step}.json")
        # Records map 1:1 to raw entries, so limit before the (costly) PDDL parse.
        if self.hparams.limit_prompts is not None:
            raw_train = raw_train[: self.hparams.limit_prompts]
        all_data = self._to_text_records(raw_train)

        if self.hparams.test_file is not None:
            train_pool = all_data
            test_pool = self._to_text_records(self._load_json(self.hparams.test_file))
        else:
            # No dedicated test file: hold out everything past the train fraction
            # so train and test never overlap.
            num_train = int(len(all_data) * self.hparams.train_size)
            train_pool = all_data[:num_train]
            test_pool = all_data[num_train:]

        train_records = train_pool[: self.hparams.max_train]
        self.train_data = PromptDataPipe(train_records)
        # Validation reuses the training records.
        self.val_data = PromptDataPipe(train_records)
        self.test_data = PromptDataPipe(test_pool[: self.hparams.max_test])

    def train_dataloader(self):
        """Return a shuffled, batch-size-1 loader over the training records."""
        return DataLoader(self.train_data, shuffle=True, batch_size=1)

    def val_dataloader(self):
        """Return a batch-size-1 loader over the validation records."""
        return DataLoader(self.val_data, batch_size=1)

    def test_dataloader(self):
        """Return a batch-size-1 loader over the test records."""
        return DataLoader(self.test_data, batch_size=1)

class PromptDataPipe(MapDataPipe):
    """Map-style datapipe exposing a pre-built sequence of records by index."""

    def __init__(self, problems) -> None:
        super().__init__()
        # Kept as-is; indexing and length delegate directly to this sequence.
        self.problems = problems

    def __getitem__(self, index):
        """Return the record stored at ``index``."""
        record = self.problems[index]
        return record

    def __len__(self):
        """Return how many records the pipe holds."""
        return len(self.problems)